import json
import os
import re

import pandas as pd
import torch
from tqdm import tqdm
from unsloth import FastLanguageModel, is_bfloat16_supported, get_chat_template

# --- Configuration ---
max_seq_length = 40960
load_in_4bit = False  # 4-bit quantization must stay disabled for this run
dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
data_path = "../data/research_data/RQ1/test_data.jsonl"
model_name = "/root/autodl-tmp/model/Qwen3-8B"
result_file_name = "result/qwen3_8B_model_thinking_outputs.xlsx"
eval_result_file = "result/qwen3_8B_thinking_metrics.txt"


# --- Load the local test set: one JSON object per line (JSONL) ---
with open(data_path, 'r', encoding='utf-8') as file:
    test_data = [json.loads(line) for line in file]

# Generation + parsing helper
def model_output(messages, model, tokenizer):
    """Run one chat generation and split the decoded text into the
    <think> reasoning section and the final answer.

    Args:
        messages: chat messages, a list of {"role": ..., "content": ...} dicts.
        model: the loaded causal LM (moved inputs go to "cuda", so a GPU
            is required).
        tokenizer: tokenizer with the chat template already applied.

    Returns:
        (think_content, output_content) as stripped strings.  The Chinese
        sentinel strings for missing sections are kept unchanged so rows
        written to Excel stay comparable with earlier runs.
    """
    FastLanguageModel.for_inference(model)  # enable unsloth's 2x-faster native inference
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # must add for generation
        return_tensors="pt",
    ).to("cuda")
    outputs = model.generate(inputs, max_new_tokens=max_seq_length, use_cache=True)
    final_answer = tokenizer.batch_decode(outputs)
    decoded = final_answer[-1]

    # Extract the <think>...</think> reasoning section.
    think_match = re.search(r"<think>(.*?)</think>", decoded, re.DOTALL)
    think_content = think_match.group(1).strip() if think_match else "无思考过程"

    # Extract the answer between </think> and <|im_end|>.  Bug fix: when
    # generation is truncated at max_new_tokens the closing <|im_end|>
    # token is missing, and the old pattern discarded a real answer as
    # "模型无输出" — fall back to everything after </think> in that case.
    output_match = re.search(r"</think>(.*?)<\|im_end\|>", decoded, re.DOTALL)
    if output_match:
        output_content = output_match.group(1).strip()
    else:
        _, sep, tail = decoded.partition("</think>")
        output_content = tail.strip() if sep else ""
    if not output_content:
        output_content = "模型无输出"

    # Unwrap a surrounding ```json fenced code block, if present.
    output_lines = output_content.splitlines()
    if len(output_lines) >= 2 and output_lines[0].strip() == "```json" and output_lines[-1].strip() == "```":
        output_content = "\n".join(output_lines[1:-1])

    print(f"think_content=======================")
    print(think_content)
    print(f"output_content=======================")
    print(output_content)
    print("==============================================")
    return think_content, output_content


# --- Load the (fine-tuned) model and its chat template ---
print("模型加载开始=======================")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    max_seq_length=max_seq_length,
)
tokenizer = get_chat_template(
    tokenizer,
    chat_template="qwen-2.5",
)
print("模型加载完成=======================")

# Resume support: if the Excel output already exists, remember which
# (repo, pr_number) pairs were processed so they can be skipped later.
processed_repo_PR_number = {}
if os.path.exists(result_file_name):
    existing_df = pd.read_excel(result_file_name)
    for _, row in existing_df.iterrows():
        processed_repo_PR_number.setdefault(row['repo'], set()).add(row['pr_number'])
else:
    # First run: create an empty workbook with the expected columns.
    existing_df = pd.DataFrame(
        columns=["pr_number", "repo", 'diff_snippet', "prompt", "模型思考", "模型输出", "need_check"])
    existing_df.to_excel(result_file_name, index=False, engine="openpyxl")

# --- Inference over the test set, checkpointing every row to Excel ---
with pd.ExcelWriter(result_file_name, engine="openpyxl", mode="a", if_sheet_exists="replace") as writer:
    for item in tqdm(test_data, desc="处理进度", unit="条"):
        pr_number = item["pr_number"]
        repo = item['repo_name']
        need_check = item["need_check"]
        diff_snippet = item["diff_snippet"]

        # Skip rows already present in the output workbook (resume support).
        if pr_number in processed_repo_PR_number.get(repo, ()):
            print(f"{repo}中的第{pr_number}号 已存在，跳过处理=======================")
            continue

        system_prompt = (
            "You are a code review decision assistant for the OpenHarmony project. "
            "Your job is: given a diff snippet, decide whether this change requires a human code review.\n"
            "You must output only the lowercase word true or false. "
            "Do not output anything else (no explanation, no punctuation, no line breaks, no code blocks).\n"
            "Repeat: the final output must be only true or false.")
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": diff_snippet},  # + " \n /no_think"
        ]

        print(f"{repo}中的第{pr_number}号 模型输出开始=======================")
        think_content, output_content = model_output(messages, model, tokenizer)

        # Append the new record and rewrite the full sheet so a crash
        # mid-run loses at most the current row.
        record = {
            "pr_number": pr_number,
            "repo": repo,
            'diff_snippet': diff_snippet,
            "prompt": messages,
            "模型思考": think_content,
            "模型输出": output_content,
            "need_check": need_check,
        }
        existing_df = pd.concat([existing_df, pd.DataFrame([record])], ignore_index=True)
        existing_df.to_excel(writer, index=False)
        writer.book.save(result_file_name)  # force the workbook onto disk now
        print(f"{repo}中的第{pr_number}号 模型输出结束=======================")

print(f"输出结果已保存到 `{result_file_name}`")
# --- Evaluation: compare the model-output column against need_check ---
print("开始计算模型准确率等指标=======================")
# Re-read the workbook so resumed/earlier rows are included in the metrics.
df = pd.read_excel(result_file_name)

# Confusion-matrix counters ("true" = the diff needs a human review).
true_positive = 0   # predicted true, label true
false_positive = 0  # predicted true, label false
false_negative = 0  # predicted false, label true
true_negative = 0   # predicted false, label false

for index, row in df.iterrows():
    # Bug fix: the loop variable was named `model_output`, shadowing the
    # inference function defined above — renamed to `prediction`.
    prediction = str(row['模型输出']).strip().lower()
    need_check = str(row['need_check']).strip().lower()

    # Rows whose output is neither "true" nor "false" are deliberately
    # skipped (they count toward no cell of the confusion matrix).
    if prediction == 'true':
        if need_check == 'true':
            true_positive += 1
        else:
            false_positive += 1
    elif prediction == 'false':
        if need_check == 'true':
            false_negative += 1
        else:
            true_negative += 1

# Accuracy / precision / recall / F1, guarding every zero denominator
# (e.g. an empty sheet, or a model that never answers true/false).
total = true_positive + false_positive + false_negative + true_negative
accuracy = (true_positive + true_negative) / total if total > 0 else 0
precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0
recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

print(f"准确率 (Accuracy): {accuracy:.4f}")
print(f"精确率 (Precision): {precision:.4f}")
print(f"召回率 (Recall): {recall:.4f}")
print(f"F1分数 (F1-Score): {f1_score:.4f}")

# Persist the metrics; explicit UTF-8 so the Chinese labels do not raise
# UnicodeEncodeError on platforms whose default encoding is not UTF-8.
with open(eval_result_file, "w", encoding="utf-8") as f:
    f.write(f"准确率 (Accuracy): {accuracy:.4f}\n")
    f.write(f"精确率 (Precision): {precision:.4f}\n")
    f.write(f"召回率 (Recall): {recall:.4f}\n")
    f.write(f"F1分数 (F1-Score): {f1_score:.4f}\n")

# Bug fix: the old message hard-coded "result/metrics.txt", which is not
# the actual output path — report the real file instead.
print(f"指标计算完成，结果已保存到 {eval_result_file} =======================")
